package com.duojuhe.websocket.socket;

import com.duojuhe.common.constant.SystemConstants;
import com.duojuhe.common.utils.idgenerator.SnowFlakeUtil;
import com.duojuhe.common.utils.token.TokenUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.tomcat.util.net.AbstractEndpoint;
import org.apache.tomcat.util.net.NioEndpoint;
import org.apache.tomcat.util.net.SocketWrapperBase;
import org.apache.tomcat.websocket.WsSession;
import org.apache.tomcat.websocket.server.WsFrameServer;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;

import javax.servlet.http.HttpServletRequest;
import java.io.EOFException;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.util.Map;
import java.util.concurrent.RejectedExecutionException;

/**
 * Socket工具类
 */
@Slf4j
public class SocketUtil {
    /**
     * socket连接的终端id
     */
    public static final String USER_ID = "userId";

    /**
     * socket连接的终端ip
     */
    public static final String CLIENT_IP = "clientIp";

    /**
     * socket连接的终端类型标识
     */
    public static final String CLIENT_TYPE = "clientType";

    /**
     * 心跳内容标识
     */
    public static final String SOCKET_PING = "ping";
    /**
     * 心跳响应标识
     */
    public static final String SOCKET_PONG = "pong";


    /**
     * 获取token
     *
     * @param session WebSocketSession
     * @return IP地址
     */
    public static String getUserToken(WebSocketSession session) {
        String token = (String) session.getAttributes().get(TokenUtils.TOKEN_HEADER);
        if (StringUtils.isBlank(token)) {
            return "";
        }
        return token;
    }

    /**
     * 获取客户端IP
     *
     * @param session WebSocketSession
     * @return IP地址
     */
    public static String getClientIp(WebSocketSession session) {
        if (session.getRemoteAddress() == null) {
            return SystemConstants.SYSTEM_IP;
        }
        String ip = (String) session.getAttributes().get(CLIENT_IP);
        if (StringUtils.isBlank(ip)) {
            return session.getRemoteAddress().getHostName();
        }
        return ip;
    }


    /**
     * 获得终端连接类型
     *
     * @param request
     * @return
     */
    public static String getClientTypeByRequest(HttpServletRequest request) {
        //ClientType
        if (StringUtils.isBlank(request.getHeader(CLIENT_TYPE))) {
            return request.getParameter(CLIENT_TYPE);
        } else {
            return request.getHeader(CLIENT_TYPE);
        }
    }

    /**
     * 获取clientType
     *
     * @param session WebSocketSession
     * @return 连接终端类型
     */
    public static String getClientType(WebSocketSession session) {
        String clientType = (String) session.getAttributes().get(CLIENT_TYPE);
        if (StringUtils.isBlank(clientType)) {
            return "";
        }
        return clientType;
    }

    /**
     * 获取用户id
     *
     * @param session WebSocketSession
     * @return IP地址
     */
    public static String getUserId(WebSocketSession session) {
        String userId = (String) session.getAttributes().get(USER_ID);
        if (StringUtils.isBlank(userId)) {
            return SnowFlakeUtil.getId();
        }
        return userId;
    }


    /**
     * 发送消息给所有已连接客户端
     */
    public static void sendMessageToAllSessionUser(SocketMessage message) {
        for (Map.Entry<String, SocketSessionUser> entry : SocketSessionCache.getAllSessionUser().entrySet()) {
            SocketUtil.sendMessageBySocketSessionUser(entry.getValue(), message.toString());
        }
    }

    /**
     * 给指定用户发送消息
     *
     * @param userId
     * @param message
     */
    public static void sendMessageToUserId(String userId, SocketMessage message) {
        for (SocketSessionUser sessionUser : SocketSessionCache.getSocketSessionUserListByUserId(userId)) {
            SocketUtil.sendMessageBySocketSessionUser(sessionUser, message.toString());
        }
    }


    /**
     * 响应消息给客户端
     */
    public static void sendMessageBySession(WebSocketSession session, SocketMessage message) {
        SocketSessionUser socketSessionUser = SocketSessionCache.getSessionUserBySessionId(session.getId());
        if (socketSessionUser != null) {
            sendMessageBySocketSessionUser(socketSessionUser, message.toString());
        }
    }


    /**
     * 发送消息给客户端，以队列的形式
     *
     * @param socketSessionUser
     * @param message
     */
    private static void sendMessageBySocketSessionUser(SocketSessionUser socketSessionUser, String message) {
        WebSocketSession session = socketSessionUser.getSession();
        if (session.isOpen()) {
            try {
                socketSessionUser.getThreadPoolExecutor().execute(() -> {
                    sendMessage(session, message);
                });
            } catch (RejectedExecutionException e) {
                //在客户端接收慢的时候，会出现队列已满的情况
                log.error("消息发送给客户端失败 队列已满,sessionId={}", session.getId());
                if (socketSessionUser.getSession().isOpen()) {
                    try {
                        socketSessionUser.getSession().close(CloseStatus.SERVICE_RESTARTED);
                    } catch (IOException e1) {
                        log.error("消息发送给客户端失败 队列已满，尝试关闭socket失败", e);
                    }
                }
            } catch (Exception e) {
                //其他未知异常
                log.error("消息发送给客户端失败，未知异常,sessionId={}", session.getId());
            }
        } else {
            if (!socketSessionUser.getThreadPoolExecutor().isShutdown()) {
                socketSessionUser.getThreadPoolExecutor().shutdownNow();
            }
        }
    }

    /**
     * 封装session发送消息
     */
    private static void sendMessage(WebSocketSession session, String message) {
        if (session != null && session.isOpen()) {
            try {
                session.sendMessage(new TextMessage(message));
            } catch (Exception e) {
                log.error("socket send fail,sessionId={}, message={}", session.getId(), message, e);
            }
        }
    }


    /**
     * 处理session传输异常
     *
     * @param session   具体的session
     * @param exception 异常信息
     */
    public static void handleTransportError(WebSocketSession session, Throwable exception) {
        if (exception instanceof EOFException) {
            log.error("a session transport error, sessionId={}, error={}", session.getId(), "EOFException");
        } else if (exception instanceof SocketTimeoutException) {
            log.error("a session transport error, sessionId={}, error={}", session.getId(), "SocketTimeoutException");
        } else {
            log.error("a session transport error, sessionId={}", session.getId(), exception);
        }
    }

    /**
     * 解决WsSession内存泄露问题
     */
    public static void release(WebSocketSession session) {
        try {
            StandardWebSocketSession standardWebSocketSession = (StandardWebSocketSession) session;
            WsSession wsSession = standardWebSocketSession.getNativeSession(WsSession.class);

            Field wsFrame = WsSession.class.getDeclaredField("wsFrame");
            wsFrame.setAccessible(true);
            WsFrameServer wsFrameServer = (WsFrameServer) wsFrame.get(wsSession);

            Field socketWrapper = WsFrameServer.class.getDeclaredField("socketWrapper");
            socketWrapper.setAccessible(true);
            SocketWrapperBase socketWrapperBase = (SocketWrapperBase) socketWrapper.get(wsFrameServer);

            Field endpoint = SocketWrapperBase.class.getDeclaredField("endpoint");
            endpoint.setAccessible(true);
            AbstractEndpoint abstractEndpoint = (AbstractEndpoint) endpoint.get(socketWrapperBase);
            NioEndpoint nioEndpoint = (NioEndpoint) abstractEndpoint;

            AbstractEndpoint.Handler handler = nioEndpoint.getHandler();
            handler.release(socketWrapperBase);

        } catch (Exception e) {
            log.error("web socket release error", e);
        }
    }

}
